import functools

import lpips
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def save_coordinate(points: torch.Tensor, filename):
    """Scatter-plot 2-D points and save the figure to *filename*.

    Args:
        points: Tensor whose row 0 holds the x coordinates and row 1 the y
            coordinates (indexed as ``points[0, :]`` / ``points[1, :]``),
            presumably shape ``(2, N)`` -- TODO confirm with callers.
        filename: Output path handed to ``Figure.savefig``.
    """
    # detach()/cpu() so tensors that track gradients or live on a GPU can
    # still be converted to NumPy.
    coords = points.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(9, 9))
    try:
        ax.scatter(coords[0, :], coords[1, :], s=1)
        ax.axis('off')
        ax.set_aspect('equal', adjustable='box')
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls leak memory.
        plt.close(fig)

def save_image(image: torch.Tensor, filename, transpose=False):
    """Min-max normalise a tensor to 8-bit pixels and save it as an image.

    Args:
        image: Tensor of pixel values. It is rescaled so its minimum maps to 0
            and its maximum to 255. A constant image is written as all zeros.
        filename: Destination path passed to ``PIL.Image.save``.
        transpose: If True, move axis 0 to the end (``(C, H, W)`` ->
            ``(H, W, C)``) before saving.
    """
    # detach()/cpu() so tensors that track gradients or live on a GPU can
    # still be converted to NumPy.
    pixels = image.detach().cpu().numpy()
    lo, hi = pixels.min(), pixels.max()
    if hi > lo:
        pixels = (pixels - lo) / (hi - lo)
    else:
        # A constant image would give 0 / 0 = NaN, and casting NaN to uint8
        # is undefined; write it as black instead.
        pixels = np.zeros_like(pixels)
    pixels = (pixels * 255).astype(np.uint8)
    if transpose:
        pixels = np.transpose(pixels, (1, 2, 0))
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; save it as
        # a 2-D grayscale image instead.
        pixels = pixels[..., 0]
    Image.fromarray(pixels).save(filename)
    
@functools.lru_cache(maxsize=None)
def _lpips_model(net):
    """Build the LPIPS model with backbone *net* once and reuse it afterwards."""
    return lpips.LPIPS(net=net)


def calculate_metric(ref_img, sample, *, data_range=1.0, lpips_net='alex'):
    """Compute SSIM, PSNR and LPIPS between a reference image and a sample.

    Args:
        ref_img: ``(C, H, W)`` tensor, expected in ``[-1, 1]``; it is mapped
            to ``[0, 1]`` with ``x / 2 + 0.5`` before comparison.
        sample: ``(C, H, W)`` tensor, assumed to already be in ``[0, 1]``.
        data_range: Span of possible pixel values of the ``[0, 1]`` images,
            passed to SSIM and PSNR. Defaults to 1.0.
        lpips_net: Backbone name given to ``lpips.LPIPS`` (default ``'alex'``).

    Returns:
        Tuple ``(ssim_value, psnr_value, lpips_value)``; the values are also
        printed.
    """
    # Data range : [0, 1]
    ref_img = ref_img / 2 + 0.5

    # detach() so tensors that track gradients can be converted to NumPy.
    ref_img_np = ref_img.permute(1, 2, 0).detach().cpu().numpy()
    sample_np = sample.permute(1, 2, 0).detach().cpu().numpy()

    # Use the nominal data range rather than the reference's observed
    # min/max, so scores are comparable across images.
    ssim_value = ssim(ref_img_np, sample_np, data_range=data_range, channel_axis=2)

    print("SSIM:", ssim_value)

    psnr_value = psnr(ref_img_np, sample_np, data_range=data_range)

    # The model is cached (weights load once) and moved to the inputs' device.
    device = ref_img.device
    loss_fn = _lpips_model(lpips_net).to(device)
    with torch.no_grad():
        # LPIPS takes batched NCHW input in [-1, 1]; normalize=True rescales
        # the [0, 1] images to that range.
        lpips_value = loss_fn(
            ref_img.unsqueeze(0),
            sample.to(device).unsqueeze(0),
            normalize=True,
        ).item()

    print("PSNR:", psnr_value)
    print("LPIPS:", lpips_value)

    return ssim_value, psnr_value, lpips_value